// Copyright (c) 2023, Arm Limited and Contributors. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#define PW_LOG_MODULE_NAME "server"
#define PW_LOG_LEVEL       PW_LOG_LEVEL_DEBUG
#include <inttypes.h>
#include <string.h>

#include "iot_socket.h"
#include "pw_rpc/server.h"

#include "iot-socket.pwpb.h"
#include "iot-socket.rpc.pwpb.h"
#include "pw_hdlc/rpc_channel.h"
#include "pw_hdlc/rpc_packets.h"
#include "pw_log/log.h"
#include "pw_rpc/channel.h"
#include "pw_stream/sys_io_stream.h"
#include "pw_string/util.h"

#include <algorithm>
#include <array>

namespace {
using CreateRequest = iotsdk::socket::pw_rpc::CreateRequest::Message;
using BindRequest = iotsdk::socket::pw_rpc::BindRequest::Message;
using ListenRequest = iotsdk::socket::pw_rpc::ListenRequest::Message;
using AcceptRequest = iotsdk::socket::pw_rpc::AcceptRequest::Message;
using ConnectRequest = iotsdk::socket::pw_rpc::ConnectRequest::Message;
using RecvRequest = iotsdk::socket::pw_rpc::RecvRequest::Message;
using RecvFromRequest = iotsdk::socket::pw_rpc::RecvFromRequest::Message;
using SendRequest = iotsdk::socket::pw_rpc::SendRequest::Message;
using SendToRequest = iotsdk::socket::pw_rpc::SendToRequest::Message;
using RecvMsgRequest = iotsdk::socket::pw_rpc::RecvMsgRequest::Message;
using SendMsgRequest = iotsdk::socket::pw_rpc::SendMsgRequest::Message;
using GetSockNameRequest = iotsdk::socket::pw_rpc::GetSockNameRequest::Message;
using GetPeerNameRequest = iotsdk::socket::pw_rpc::GetPeerNameRequest::Message;
using GetOptRequest = iotsdk::socket::pw_rpc::GetOptRequest::Message;
using SetOptRequest = iotsdk::socket::pw_rpc::SetOptRequest::Message;
using CloseRequest = iotsdk::socket::pw_rpc::CloseRequest::Message;
using ShutdownRequest = iotsdk::socket::pw_rpc::ShutdownRequest::Message;
using GetHostByNameRequest = iotsdk::socket::pw_rpc::GetHostByNameRequest::Message;
using SelectRequest = iotsdk::socket::pw_rpc::SelectRequest::Message;
using GenericResponse = iotsdk::socket::pw_rpc::GenericResponse::Message;
using AcceptResponse = iotsdk::socket::pw_rpc::AcceptResponse::Message;
using RecvResponse = iotsdk::socket::pw_rpc::RecvResponse::Message;
using RecvFromResponse = iotsdk::socket::pw_rpc::RecvFromResponse::Message;
using RecvMsgResponse = iotsdk::socket::pw_rpc::RecvMsgResponse::Message;
using GetSockNameResponse = iotsdk::socket::pw_rpc::GetSockNameResponse::Message;
using GetPeerNameResponse = iotsdk::socket::pw_rpc::GetPeerNameResponse::Message;
using GetOptResponse = iotsdk::socket::pw_rpc::GetOptResponse::Message;
using GetHostByNameResponse = iotsdk::socket::pw_rpc::GetHostByNameResponse::Message;
using SelectResponse = iotsdk::socket::pw_rpc::SelectResponse::Message;

template <class T> using ServiceBase = iotsdk::socket::pw_rpc::pw_rpc::pwpb::IotSocketService::Service<T>;

constexpr size_t kBufferSize = 2048;
constexpr size_t kRecvBufferSize = 512;    // set according to proto options file
constexpr size_t kNameSize = 128;          // set according to proto options file
constexpr size_t kControlBufferSize = 512; // set according to proto options file
constexpr size_t kMaxAddressSize = 16;
constexpr size_t kMaxIOVecNum = 10;

pw::stream::SysIoWriter writer;
pw::hdlc::RpcChannelOutput hdlc_channel_output(writer, pw::hdlc::kDefaultRpcAddress, "HDLC output");
std::array<pw::rpc::Channel, 1> channels{pw::rpc::Channel::Create<1>(&hdlc_channel_output)};
pw::rpc::Server server(channels);

// Implementation of the generated IotSocketService.
//
// Each RPC forwards its request to the matching iotSocket* function and puts
// that function's return value in `response.status`. Negative values are logged
// as warnings. The RPC-level pw::Status is non-OK only when the request itself
// cannot be serviced: malformed or oversized arguments, or a failed allocation.
//
// Every length the client supplies (len, ip_len, name_len, msg_controllen and
// the lengths embedded in SendMsg's msg_iov blob) is checked against the bytes
// actually decoded before anything is read. A malformed request therefore
// cannot make the server read stale or out-of-range memory.
struct Service : ServiceBase<Service> {
    // Creates a socket. On success response.status holds the new socket id.
    pw::Status Create(const CreateRequest &request, GenericResponse &response)
    {
        response.status = iotSocketCreate(request.af, request.type, request.protocol);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketCreate: %" PRId32, response.status);
        } else {
            PW_LOG_DEBUG("Socket %" PRId32 " created", response.status);
        }
        return pw::OkStatus();
    }

    // Binds request.socket to the first ip_len bytes of request.ip and request.port.
    pw::Status Bind(const BindRequest &request, GenericResponse &response)
    {
        // ip_len must not claim more address bytes than were actually decoded.
        if (request.ip_len > request.ip.size()) {
            return pw::Status::InvalidArgument();
        }

        response.status =
            iotSocketBind(request.socket, (const uint8_t *)request.ip.data(), request.ip_len, request.port);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketBind: %" PRId32, response.status);
        } else {
            PW_LOG_DEBUG("Socket %" PRId32 " bound", request.socket);
        }
        return pw::OkStatus();
    }

    // Puts request.socket into listening mode with the given backlog.
    pw::Status Listen(const ListenRequest &request, GenericResponse &response)
    {
        response.status = iotSocketListen(request.socket, request.backlog);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketListen: %" PRId32, response.status);
        } else {
            PW_LOG_DEBUG("Socket %" PRId32 " listening", request.socket);
        }
        return pw::OkStatus();
    }

    // Accepts a connection on request.socket. On success response.status is the
    // accepted socket, and response.ip/ip_len/port describe the peer.
    pw::Status Accept(const AcceptRequest &request, AcceptResponse &response)
    {
        uint16_t port = 0; // stays defined even if the call fails
        response.ip.resize(response.ip.max_size());
        response.ip_len = sizeof(iot_in6_addr);
        response.status = iotSocketAccept(request.socket, (uint8_t *)response.ip.data(), &response.ip_len, &port);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketAccept: %" PRId32, response.status);
            response.ip.resize(0);
        } else {
            TrimToLength(response.ip, response.ip_len);
            PW_LOG_DEBUG("Socket %" PRId32 " accepted connection", response.status);
        }

        response.port = (uint32_t)port;
        return pw::OkStatus();
    }

    // Connects request.socket to the first ip_len bytes of request.ip and request.port.
    pw::Status Connect(const ConnectRequest &request, GenericResponse &response)
    {
        if (request.ip_len > request.ip.size()) {
            return pw::Status::InvalidArgument();
        }

        response.status =
            iotSocketConnect(request.socket, (const uint8_t *)request.ip.data(), request.ip_len, request.port);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketConnect: %" PRId32, response.status);
        } else {
            PW_LOG_DEBUG("Socket %" PRId32 " connected", request.socket);
        }
        return pw::OkStatus();
    }

    // Receives up to request.len bytes. response.buf holds exactly the bytes
    // received, and is empty on error.
    pw::Status Recv(const RecvRequest &request, RecvResponse &response)
    {
        if (request.len > response.buf.max_size()) {
            return pw::Status::InvalidArgument();
        }
        response.buf.resize(request.len);

        response.status = iotSocketRecv(request.socket, response.buf.data(), request.len);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketRecv: %" PRId32, response.status);
            response.buf.resize(0);
        } else {
            TrimToLength(response.buf, static_cast<uint32_t>(response.status));
        }

        PW_LOG_DEBUG("Socket %" PRId32 " read %" PRId32 " data", request.socket, response.status);
        return pw::OkStatus();
    }

    // Like Recv, but also reports the sender's address (ip/ip_len) and port.
    pw::Status RecvFrom(const RecvFromRequest &request, RecvFromResponse &response)
    {
        if (request.len > response.buf.max_size()) {
            return pw::Status::InvalidArgument();
        }
        response.buf.resize(request.len);
        response.ip_len = sizeof(iot_in6_addr);
        response.ip.resize(response.ip.max_size());

        uint16_t port = 0;
        response.status = iotSocketRecvFrom(request.socket,
                                            (uint8_t *)response.buf.data(),
                                            request.len,
                                            (uint8_t *)response.ip.data(),
                                            &response.ip_len,
                                            &port);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketRecvFrom: %" PRId32, response.status);
            response.buf.resize(0);
            response.ip.resize(0);
        } else {
            TrimToLength(response.buf, static_cast<uint32_t>(response.status));
            TrimToLength(response.ip, response.ip_len);
        }

        response.port = (uint32_t)port;

        PW_LOG_DEBUG("Socket %" PRId32 " read %" PRId32 " data", request.socket, response.status);
        return pw::OkStatus();
    }

    // Sends the first request.len bytes of request.buf.
    pw::Status Send(const SendRequest &request, GenericResponse &response)
    {
        // Check against the decoded size, not the capacity: bytes past size() were never received.
        if (request.len > request.buf.size()) {
            return pw::Status::InvalidArgument();
        }

        response.status = iotSocketSend(request.socket, request.buf.data(), request.len);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketSend: %" PRId32, response.status);
        }

        PW_LOG_DEBUG("Socket %" PRId32 " sent %" PRId32 " data", request.socket, response.status);
        return pw::OkStatus();
    }

    // Sends the first request.len bytes of request.buf to the given ip/port.
    pw::Status SendTo(const SendToRequest &request, GenericResponse &response)
    {
        if (request.len > request.buf.size() || request.ip_len > request.ip.size()) {
            return pw::Status::InvalidArgument();
        }

        response.status = iotSocketSendTo(request.socket,
                                          request.buf.data(),
                                          request.len,
                                          (const uint8_t *)request.ip.data(),
                                          request.ip_len,
                                          (uint16_t)request.port);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketSendTo: %" PRId32, response.status);
        }

        PW_LOG_DEBUG("Socket %" PRId32 " sent %" PRId32 " data", request.socket, response.status);
        return pw::OkStatus();
    }

    // Receives one message into a single I/O vector of kRecvBufferSize bytes.
    // response.msg_iov is encoded as in SendMsg: a big-endian uint32_t length
    // followed by the data. On error no address or control bytes are returned.
    pw::Status RecvMsg(const RecvMsgRequest &request, RecvMsgResponse &response)
    {
        std::byte *control_data = (std::byte *)malloc(kControlBufferSize);
        // Room for the 4-byte length prefix in front of the received data.
        std::byte *data = (std::byte *)malloc(kRecvBufferSize + sizeof(uint32_t));
        if (control_data == nullptr || data == nullptr) {
            free(control_data);
            free(data);
            return pw::Status::Internal();
        }

        iot_iovec iov_vector;
        iov_vector.iov_base = data + sizeof(uint32_t);
        iov_vector.iov_len = kRecvBufferSize;

        std::byte address[kMaxAddressSize] = {};

        struct iot_msghdr msg = {0};
        msg.msg_name = &address;
        msg.msg_namelen = kMaxAddressSize;
        msg.msg_iov = &iov_vector;
        msg.msg_iovlen = 1;
        msg.msg_control = control_data;
        msg.msg_controllen = kControlBufferSize;

        response.status = iotSocketRecvMsg(request.socket, &msg, request.flags);

        uint32_t data_len = 0;
        uint32_t name_len = 0;
        uint32_t control_len = 0;
        if (response.status < 0) {
            PW_LOG_WARN("iotSocketRecvMsg: %" PRId32, response.status);
        } else {
            // Clamp every returned length to the buffer it refers to. A reported
            // length larger than the buffer (e.g. a truncated address) must never
            // drive a copy past the end of that buffer.
            data_len = std::min<uint32_t>(static_cast<uint32_t>(response.status), kRecvBufferSize);
            name_len = std::min<uint32_t>(msg.msg_namelen, kMaxAddressSize);
            control_len = std::min<uint32_t>(msg.msg_controllen, kControlBufferSize);
        }

        // we have to convert array of pointers (in this case one pointer) into solid memory
        // this is to remain consistent with the SendMsg, length of data, followed by data
        data[0] = (std::byte)((data_len >> 24) & 0xff);
        data[1] = (std::byte)((data_len >> 16) & 0xff);
        data[2] = (std::byte)((data_len >> 8) & 0xff);
        data[3] = (std::byte)((data_len >> 0) & 0xff);
        response.msg_iov.assign(data, data + sizeof(uint32_t) + data_len);

        response.msg_name.assign(address, address + name_len);
        response.msg_namelen = name_len;
        response.msg_iovlen = 1;

        // On error control_len is 0, so uninitialized heap bytes are never sent back.
        response.msg_control.assign(control_data, control_data + control_len);
        response.msg_controllen = control_len;
        response.msg_flags = msg.msg_flags;

        PW_LOG_DEBUG("Socket %" PRId32 " read %" PRId32 " data", request.socket, response.status);

        free(control_data);
        free(data);
        return pw::OkStatus();
    }

    // Sends a message whose data arrives as one flat blob in request.msg_iov.
    // The blob holds msg_iovlen records, each a big-endian uint32_t length
    // followed by that many bytes. request.ip (4 or 16 bytes) and request.port
    // form the destination address.
    pw::Status SendMsg(const SendMsgRequest &request, GenericResponse &response)
    {
        if (request.msg_iovlen > kMaxIOVecNum) {
            return pw::Status::Internal();
        }
        if (request.msg_controllen > request.msg_control.size()) {
            return pw::Status::InvalidArgument();
        }

        // we have to convert solid memory into an array of pointers
        iot_iovec vectors[kMaxIOVecNum];
        uint8_t *iov_data = (uint8_t *)request.msg_iov.data();
        size_t remaining = request.msg_iov.size();
        for (size_t i = 0; i < request.msg_iovlen; i++) {
            // Each record needs a complete length prefix and a payload that fits in what is left.
            if (remaining < sizeof(uint32_t)) {
                return pw::Status::InvalidArgument();
            }
            // Widen before shifting so bytes >= 0x80 do not overflow a signed int.
            const uint32_t len = ((uint32_t)iov_data[0] << 24) | ((uint32_t)iov_data[1] << 16)
                                 | ((uint32_t)iov_data[2] << 8) | (uint32_t)iov_data[3];
            iov_data += sizeof(uint32_t);
            remaining -= sizeof(uint32_t);
            if (len > remaining) {
                return pw::Status::InvalidArgument();
            }
            vectors[i].iov_base = iov_data;
            vectors[i].iov_len = len;
            iov_data += len;
            remaining -= len;
        }

        iot_msghdr msg = {0};
        iot_sockaddr_in_any peer_addr = {0};
        msg.msg_name = &peer_addr;
        uint8_t *port_bytes;

        if (request.ip.size() == sizeof(peer_addr.in6.sin6_addr)) {
            peer_addr.in6.sin6_len = sizeof(iot_sockaddr_in6);
            peer_addr.in6.sin6_family = IOT_SOCKET_AF_INET6;
            port_bytes = (uint8_t *)&peer_addr.in6.sin6_port;
            memcpy(&peer_addr.in6.sin6_addr, request.ip.data(), request.ip.size());
            msg.msg_namelen = sizeof(iot_sockaddr_in6);
        } else if (request.ip.size() <= sizeof(peer_addr.in.sin_addr)) {
            peer_addr.in.sin_len = sizeof(iot_sockaddr_in);
            peer_addr.in.sin_family = IOT_SOCKET_AF_INET;
            port_bytes = (uint8_t *)&peer_addr.in.sin_port;
            memcpy(&peer_addr.in.sin_addr, request.ip.data(), request.ip.size());
            msg.msg_namelen = sizeof(iot_sockaddr_in);
        } else {
            // Neither an IPv6 address nor something that fits in an IPv4 address.
            return pw::Status::InvalidArgument();
        }
        // this is normally handled by the iot_socket but here we are building a binary blob
        // of the socket addr so we need to fix the network order ourselves
        port_bytes[0] = (uint8_t)((request.port & 0xff00) >> 8);
        port_bytes[1] = (uint8_t)(request.port & 0xff);

        msg.msg_iov = vectors;
        msg.msg_iovlen = request.msg_iovlen;
        msg.msg_control = (void *)request.msg_control.data();
        msg.msg_controllen = request.msg_controllen;

        response.status = iotSocketSendMsg(request.socket, &msg, request.flags);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketSendMsg: %" PRId32, response.status);
        }

        PW_LOG_DEBUG("Socket %" PRId32 " sent %" PRId32 " data", request.socket, response.status);
        return pw::OkStatus();
    }

    // Reports the local address (ip/ip_len) and port of request.socket.
    pw::Status GetSockName(const GetSockNameRequest &request, GetSockNameResponse &response)
    {
        uint16_t port = 0;
        response.ip.resize(response.ip.max_size());
        response.ip_len = sizeof(iot_in6_addr);
        response.status = iotSocketGetSockName(request.socket, (uint8_t *)response.ip.data(), &response.ip_len, &port);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketGetSockName: %" PRId32, response.status);
            response.ip.resize(0);
        } else {
            TrimToLength(response.ip, response.ip_len);
        }

        response.port = (uint32_t)port;
        return pw::OkStatus();
    }

    // Reports the peer address (ip/ip_len) and port of request.socket.
    pw::Status GetPeerName(const GetPeerNameRequest &request, GetPeerNameResponse &response)
    {
        uint16_t port = 0;
        response.ip.resize(response.ip.max_size());
        response.ip_len = sizeof(iot_in6_addr);
        response.status = iotSocketGetPeerName(request.socket, (uint8_t *)response.ip.data(), &response.ip_len, &port);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketGetPeerName: %" PRId32, response.status);
            response.ip.resize(0);
        } else {
            TrimToLength(response.ip, response.ip_len);
        }

        response.port = (uint32_t)port;
        return pw::OkStatus();
    }

    // Reads socket option request.opt_id into response.val, which is empty on error.
    pw::Status GetOpt(const GetOptRequest &request, GetOptResponse &response)
    {
        uint32_t val_len = response.val.max_size();
        response.val.resize(val_len);
        response.status = iotSocketGetOpt(request.socket, request.opt_id, response.val.data(), &val_len);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketGetOpt: %" PRId32, response.status);
            response.val.resize(0);
        } else {
            response.val.resize(val_len);
        }

        return pw::OkStatus();
    }

    // Sets socket option request.opt_id to the bytes in request.val.
    pw::Status SetOpt(const SetOptRequest &request, GenericResponse &response)
    {
        response.status = iotSocketSetOpt(request.socket, request.opt_id, request.val.data(), request.val.size());

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketSetOpt: %" PRId32, response.status);
        }

        return pw::OkStatus();
    }

    // Closes request.socket.
    pw::Status Close(const CloseRequest &request, GenericResponse &response)
    {
        response.status = iotSocketClose(request.socket);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketClose: %" PRId32, response.status);
        }

        return pw::OkStatus();
    }

    // Shuts down request.socket according to request.option.
    pw::Status Shutdown(const ShutdownRequest &request, GenericResponse &response)
    {
        response.status = iotSocketShutdown(request.socket, request.option);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketShutdown: %" PRId32, response.status);
        }

        return pw::OkStatus();
    }

    // Resolves the first name_len bytes of request.name for address family
    // request.af. On success response.ip holds exactly ip_len bytes.
    pw::Status GetHostByName(const GetHostByNameRequest &request, GetHostByNameResponse &response)
    {
        if (request.name_len > kNameSize || request.name_len > request.name.size()) {
            return pw::Status::InvalidArgument();
        }

        // Create a nul-terminated copy of the name.
        char name[kNameSize + 1] = {};
        memcpy(name, request.name.data(), request.name_len);

        response.ip.resize(response.ip.max_size());
        response.ip_len = (request.af == IOT_SOCKET_AF_INET) ? sizeof(iot_in_addr) : sizeof(iot_in6_addr);

        response.status = iotSocketGetHostByName(name, request.af, (uint8_t *)response.ip.data(), &response.ip_len);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketGetHostByName: %" PRId32, response.status);
            response.ip.resize(0);
        } else {
            TrimToLength(response.ip, response.ip_len);
        }

        return pw::OkStatus();
    }

    // Waits up to request.timeout_ms for the listed sockets to become readable
    // or writable, and reports the ready ones. No sockets are reported on error.
    pw::Status Select(const SelectRequest &request, SelectResponse &response)
    {
        size_t select_len = iotSocketMaskGetSize();
        uint8_t *read_mask = (uint8_t *)malloc(select_len);
        uint8_t *write_mask = (uint8_t *)malloc(select_len);
        if (read_mask == nullptr || write_mask == nullptr) {
            free(read_mask);
            free(write_mask);
            return pw::Status::Internal();
        }
        iotSocketMaskZero(read_mask);
        iotSocketMaskZero(write_mask);

        for (size_t i = 0; i < request.read_sockets.size(); i++) {
            iotSocketMaskSet(request.read_sockets[i], read_mask);
        }
        for (size_t i = 0; i < request.write_sockets.size(); i++) {
            iotSocketMaskSet(request.write_sockets[i], write_mask);
        }

        response.status = iotSocketSelect(read_mask, write_mask, nullptr, request.timeout_ms);
        if (response.status < 0) {
            PW_LOG_WARN("Select: %" PRId32, response.status);
        } else {
            // NOTE(review): socket ids are probed only up to select_len, the
            // mask's allocation size. Confirm against iotSocketMaskGetSize's
            // contract whether that bounds the socket id range.
            for (size_t i = 0; i < select_len; i++) {
                if (iotSocketMaskIsSet(i, read_mask)) {
                    response.read_sockets.push_back(i);
                }
                if (iotSocketMaskIsSet(i, write_mask)) {
                    response.write_sockets.push_back(i);
                }
            }
        }

        free(read_mask);
        free(write_mask);

        return pw::OkStatus();
    }

private:
    // Shrinks `field`, which was pre-sized before an iot_socket call, to the
    // `len` bytes that call reported. The field never grows, so an
    // over-reported length cannot push it past the size it had during the call.
    template <typename Field> static void TrimToLength(Field &field, uint32_t len)
    {
        field.resize(std::min<size_t>(field.size(), len));
    }
} service;
} // namespace

// Starts the RPC server. It registers the socket service with `server`, then
// hands control to pw::hdlc::ReadAndProcessPackets, which reads HDLC frames
// into a static buffer and dispatches them over `hdlc_channel_output`.
// NOTE(review): whether ReadAndProcessPackets ever returns is not visible here — confirm.
void start_rpc_app()
{
    PW_LOG_INFO("Starting server");
    server.RegisterService(service);

    // Static storage keeps the packet buffer out of this function's stack frame.
    static std::array<std::byte, kBufferSize> frame_buffer;
    pw::hdlc::ReadAndProcessPackets(server, hdlc_channel_output, frame_buffer);
}
